import math
import torch as _torch
import numpy as _numpy
from numpy.linalg import inv

class numpy:
    """NumPy implementations of the rotation / constrained-homography helpers.

    The class is only a namespace for static methods. It shares its name with
    the NumPy package, which is why this file imports the package as
    ``_numpy``.
    """

    @staticmethod
    def make_rotation_matrix(rx=0.0, ry=0.0, rz=0.0, rot_order='xyz', dtype=_numpy.float32):
        """Compose a 3x3 rotation matrix from per-axis angles.

        Args:
            rx, ry, rz: Scalar angles in radians (passed to ``math.cos`` /
                ``math.sin``) about the x, y and z axes.
            rot_order: Iterable of the characters ``'x'``, ``'y'``, ``'z'``.
                The per-axis matrices are right-multiplied in this order, so
                ``'xyz'`` yields ``Rx @ Ry @ Rz``. Axes may repeat or be
                omitted; an empty order returns the identity.
            dtype: dtype of the per-axis matrices and of the result.

        Returns:
            ``(3, 3)`` array of the requested dtype.

        Raises:
            KeyError: if ``rot_order`` contains anything other than x/y/z.
        """
        cos_x, sin_x = math.cos(rx), math.sin(rx)
        cos_y, sin_y = math.cos(ry), math.sin(ry)
        cos_z, sin_z = math.cos(rz), math.sin(rz)

        # NOTE: the y-axis matrix carries -sin in its top-right entry, i.e.
        # the opposite sign to the x and z matrices' cyclic pattern.
        per_axis = {
            'x': _numpy.array(
                [[1.0, 0.0, 0.0], [0.0, cos_x, -sin_x], [0.0, sin_x, cos_x]], dtype=dtype),
            'y': _numpy.array(
                [[cos_y, 0.0, -sin_y], [0.0, 1.0, 0.0], [sin_y, 0.0, cos_y]], dtype=dtype),
            'z': _numpy.array(
                [[cos_z, -sin_z, 0.0], [sin_z, cos_z, 0.0], [0.0, 0.0, 1.0]], dtype=dtype),
        }

        rotation = _numpy.eye(3, dtype=dtype)
        for axis in rot_order:
            rotation = rotation @ per_axis[axis]
        return rotation

    @staticmethod
    def make_constrained_homography(R, tz, K, bv_pivot, pv_pivot):
        """Build ``K @ (R - t nᵀ) @ K⁻¹`` with ``t`` constrained by two pivots.

        ``n`` is ``R⁻¹ @ [0, 0, 1]ᵀ``. The third component of ``t`` is ``tz``;
        the first two are solved in closed form so that, for
        ``b = K⁻¹ @ bv_pivot`` and ``p = K⁻¹ @ pv_pivot`` with unit third
        component, ``(R - t nᵀ) @ p`` is proportional to ``b``.

        Args:
            R: ``(3, 3)`` array (indexed as ``R[i, j]``).
            tz: Scalar (or one-element array) third translation component.
            K: ``(3, 3)`` invertible matrix; presumably the camera
                intrinsics -- confirm against callers.
            bv_pivot: Three values, reshaped to ``(3, 1)``.
            pv_pivot: Three values, reshaped to ``(3, 1)``.

        Returns:
            ``(3, 3)`` array ``H``.
        """
        intrinsics_inv = inv(K)
        unit_z = _numpy.array([[0.0], [0.0], [1.0]])
        normal = inv(R) @ unit_z

        bv = intrinsics_inv @ _numpy.asarray(bv_pivot).reshape(3, 1)
        pv = intrinsics_inv @ _numpy.asarray(pv_pivot).reshape(3, 1)
        t_z = _numpy.asarray(tz).reshape(1)

        n_x, n_y, n_z = normal[0, 0], normal[1, 0], normal[2, 0]
        # Shared by both solved components: nᵀ @ [p0, p1, 1].
        denominator = n_z + n_x*pv[0] + n_y*pv[1]

        def solve_component(i):
            # Closed-form solution for t[i], i in {0, 1}; the term order is
            # kept fixed so results stay numerically stable across edits.
            numerator = (R[i,2] - bv[i]*R[2,2] + pv[0]*R[i,0] + pv[1]*R[i,1]
                         + bv[i]*n_z*t_z - bv[i]*pv[0]*R[2,0] - bv[i]*pv[1]*R[2,1]
                         + bv[i]*n_x*pv[0]*t_z + bv[i]*n_y*pv[1]*t_z)
            return numerator / denominator

        translation = _numpy.concatenate(
            (solve_component(0), solve_component(1), t_z))[:, None]
        return K @ (R - translation @ normal.T) @ intrinsics_inv

class torch:
    """Batched PyTorch implementations of the rotation / homography helpers.

    The class is only a namespace for static methods. It shares its name with
    the PyTorch package, which is why this file imports the package as
    ``_torch``.
    """

    @staticmethod
    def make_rotation_matrix(rx=0.0, ry=0.0, rz=0.0, rot_order='xyz', dtype=None, device=None):
        """Compose a batch of 3x3 rotation matrices from per-axis angles.

        Args:
            rx, ry, rz: Angles in radians about the x, y and z axes. Each may
                be a Python scalar, a 0-dim tensor, or a 1-D tensor. The batch
                size is the largest length among the three; length-1 angles
                are broadcast across the batch.
            rot_order: Iterable of the characters ``'x'``, ``'y'``, ``'z'``.
                The per-axis matrices are right-multiplied in this order, so
                ``'xyz'`` yields ``Rx @ Ry @ Rz``.
            dtype, device: Used for scalar-to-tensor conversion and for
                allocating the matrices. The result therefore follows these
                arguments (or the PyTorch defaults when ``None``), not the
                dtype/device of tensor angles.

        Returns:
            Tensor of shape ``(B, 3, 3)``.

        Raises:
            KeyError: if ``rot_order`` contains anything other than x/y/z.
            RuntimeError: (from PyTorch) if the angle lengths cannot be
                broadcast to a common batch size.
        """
        def as_batch(angle):
            # Scalars become length-1 tensors; 0-dim tensors are promoted so
            # that ``shape[0]`` below is always defined.
            if not isinstance(angle, _torch.Tensor):
                return _torch.tensor([angle], dtype=dtype, device=device)
            return angle.reshape(1) if angle.dim() == 0 else angle

        rx, ry, rz = as_batch(rx), as_batch(ry), as_batch(rz)
        # Use the largest batch among all three angles. Sizing from rx alone
        # broke calls such as make_rotation_matrix(ry=batched_tensor).
        B = max(rx.shape[0], ry.shape[0], rz.shape[0])

        def axis_rotation(fixed, a, b, angle):
            # Rotation in the (a, b) plane, leaving axis ``fixed`` untouched.
            c, s = _torch.cos(angle), _torch.sin(angle)
            M = _torch.zeros(B, 3, 3, dtype=dtype, device=device)
            M[:, fixed, fixed] = 1.0
            M[:, a, a], M[:, a, b], M[:, b, a], M[:, b, b] = c, -s, s, c
            return M

        # NOTE: the y-axis matrix carries -sin in its top-right entry, i.e.
        # the opposite sign to the x and z matrices' cyclic pattern.
        _R = {
            'x': axis_rotation(0, 1, 2, rx),
            'y': axis_rotation(1, 0, 2, ry),
            'z': axis_rotation(2, 0, 1, rz),
        }

        R = _torch.eye(3, device=device, dtype=dtype).repeat(B, 1, 1)
        for axis in rot_order:
            R = _torch.bmm(R, _R[axis])
        return R

    @staticmethod
    def make_constrained_homography(R, tz, K, K_inv, bv_pivot_camera, pv_pivot_camera):
        """Build ``K @ (R - t nᵀ)⁻¹ @ K_inv`` with ``t`` constrained by two pivots.

        ``n`` is ``R⁻¹ @ [0, 0, 1]ᵀ`` per batch item. The third component of
        ``t`` is ``tz``; the first two are solved in closed form so that, for
        ``b = bv_pivot_camera`` and ``p = pv_pivot_camera`` with unit third
        component, ``(R - t nᵀ) @ p`` is proportional to ``b``. The bracketed
        matrix is inverted before being wrapped by ``K`` / ``K_inv``, so the
        returned ``H`` sends ``K @ b`` to a multiple of ``K @ p``.

        Args:
            R: ``(B, 3, 3)`` tensor; its dtype and device are used for ``n``.
            tz: Tensor expandable to shape ``(B,)``.
            K, K_inv: Tensors matmul-compatible with ``(B, 3, 3)``; presumably
                camera intrinsics and their inverse -- confirm against callers.
            bv_pivot_camera: Three-element tensor, viewed as ``(3, 1)``; the
                name suggests it is already in camera (``K_inv``-applied)
                coordinates -- confirm.
            pv_pivot_camera: Three-element tensor, viewed as ``(3, 1)``; same
                assumption as above.

        Returns:
            Tensor ``H``, of shape ``(B, 3, 3)`` when ``K`` and ``K_inv`` are
            ``(3, 3)`` or ``(B, 3, 3)``.
        """
        B, device, dtype = R.shape[0], R.device, R.dtype
        unit_z = _torch.tensor([[[0], [0], [1]]], dtype=dtype, device=device).expand(B, 3, 1)
        n = _torch.bmm(_torch.inverse(R), unit_z)
        b, p, t2 = bv_pivot_camera.view(3, 1), pv_pivot_camera.view(3, 1), tz.expand(B)
        # Shared by both solved components: nᵀ @ [p0, p1, 1] per batch item.
        denom = n[:,2,0] + n[:,0,0]*p[0] + n[:,1,0]*p[1]
        # The two numerators below come from expressions auto-generated by
        # docs/yield_constrained_t.m; keep the term order when regenerating.
        t0 = (R[:,0,2] - b[0]*R[:,2,2] + p[0]*R[:,0,0] + p[1]*R[:,0,1] + b[0]*n[:,2,0]*t2 - b[0]*p[0]*R[:,2,0] - b[0]*p[1]*R[:,2,1] + b[0]*n[:,0,0]*p[0]*t2 + b[0]*n[:,1,0]*p[1]*t2)/denom
        t1 = (R[:,1,2] - b[1]*R[:,2,2] + p[0]*R[:,1,0] + p[1]*R[:,1,1] + b[1]*n[:,2,0]*t2 - b[1]*p[0]*R[:,2,0] - b[1]*p[1]*R[:,2,1] + b[1]*n[:,0,0]*p[0]*t2 + b[1]*n[:,1,0]*p[1]*t2)/denom
        t = _torch.cat((t0[:, None, None], t1[:, None, None], t2[:, None, None]), dim=1)
        H = K @ _torch.inverse(R - t @ n.transpose(1, 2)) @ K_inv
        return H